/[finished]Assignment_2_word2vec/sgd.py
#!/usr/bin/env python

import glob
import pickle
import random
import os.path as op

import numpy as np

# Save parameters every few SGD iterations as a fail-safe
SAVE_PARAMS_EVERY = 5000


def load_saved_params():
    """
    A helper function that loads previously saved parameters and resets
    the iteration start.
    """
    st = 0
    # Scan for checkpoints and keep the highest iteration number found.
    for f in glob.glob("saved_params_*.npy"):
        iter_ = int(op.splitext(op.basename(f))[0].split("_")[2])
        if iter_ > st:
            st = iter_

    if st > 0:
        params_file = "saved_params_%d.npy" % st
        state_file = "saved_state_%d.pickle" % st
        params = np.load(params_file)
        with open(state_file, "rb") as f:
            state = pickle.load(f)
        return st, params, state
    else:
        return st, None, None


def save_params(iter_, params):
    params_file = "saved_params_%d.npy" % iter_
    np.save(params_file, params)
    # Also save the RNG state so a resumed run replays the same randomness.
    with open("saved_state_%d.pickle" % iter_, "wb") as f:
        pickle.dump(random.getstate(), f)
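

# Illustrative sketch (not part of the original file): how the two helpers
# above round-trip a checkpoint. The function name and the iteration number
# 100 are invented for the example, and it assumes no higher-numbered
# checkpoint already sits in the current directory, since load_saved_params()
# globs there and picks the latest one.
def _checkpoint_roundtrip_demo():
    params = np.arange(6, dtype=float).reshape(3, 2)
    save_params(100, params)  # writes saved_params_100.npy / saved_state_100.pickle
    start_iter, loaded, state = load_saved_params()
    assert start_iter == 100
    assert np.allclose(loaded, params)
    random.setstate(state)  # restore the RNG exactly where it was saved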


def sgd(f, x0, step, iterations, postprocessing=None, useSaved=False,
        PRINT_EVERY=10):
    """ Stochastic Gradient Descent

    Implement the stochastic gradient descent method in this function.

    Arguments:
    f -- the function to optimize; it takes a single argument and
         returns two outputs, the loss and the gradient with respect
         to that argument
    x0 -- the initial point to start SGD from
    step -- the step size (learning rate) for SGD
    iterations -- total number of iterations to run SGD for
    postprocessing -- postprocessing function applied to the parameters
         after each update, if necessary; for word2vec we use it to
         normalize the word vectors to unit length
    PRINT_EVERY -- how often (in iterations) to print the loss

    Return:
    x -- the parameter value after SGD finishes
    """
    # Anneal the learning rate every several iterations
    ANNEAL_EVERY = 20000

    if useSaved:
        start_iter, oldx, state = load_saved_params()
        if start_iter > 0:
            x0 = oldx
            # Replay the annealing schedule up to the saved iteration.
            # Integer division: the step is halved once per completed
            # ANNEAL_EVERY window.
            step *= 0.5 ** (start_iter // ANNEAL_EVERY)
        if state:
            random.setstate(state)
    else:
        start_iter = 0

    x = x0
    if not postprocessing:
        postprocessing = lambda x: x

    exploss = None

    for iter_ in range(start_iter + 1, iterations + 1):
        # You might want to print the progress every few iterations.
        ### YOUR CODE HERE
        loss, gd = f(x)        # evaluate loss and gradient at the current point
        x = x - step * gd      # vanilla SGD update
        x = postprocessing(x)  # e.g. re-normalize word vectors to unit length
        ### END YOUR CODE

        if iter_ % PRINT_EVERY == 0:
            # Report an exponential moving average of the loss to smooth
            # out mini-batch noise.
            if not exploss:
                exploss = loss
            else:
                exploss = .95 * exploss + .05 * loss
            print("iter %d: %f" % (iter_, exploss))

        if iter_ % SAVE_PARAMS_EVERY == 0 and useSaved:
            save_params(iter_, x)

        if iter_ % ANNEAL_EVERY == 0:
            step *= 0.5

    return x
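

# Illustrative sketch (not part of the assignment): adapting a mini-batch
# objective to the single-argument interface sgd() expects. The toy
# least-squares problem and all names here (A, b, batch) are invented for
# the example; f closes over the data and samples a fresh mini-batch on
# each call, which is what makes the descent stochastic.
def _least_squares_demo():
    rng = np.random.RandomState(0)
    A = rng.randn(200, 3)
    true_x = np.array([1.0, -2.0, 0.5])
    b = A.dot(true_x)

    def f(x):
        batch = rng.choice(len(A), size=20, replace=False)
        resid = A[batch].dot(x) - b[batch]
        loss = 0.5 * np.mean(resid ** 2)
        grad = A[batch].T.dot(resid) / len(batch)
        return loss, grad

    # Should converge close to true_x.
    return sgd(f, np.zeros(3), 0.05, 2000, PRINT_EVERY=500)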


def sanity_check():
    quad = lambda x: (np.sum(x ** 2), x * 2)

    print("Running sanity checks...")
    t1 = sgd(quad, 0.5, 0.01, 1000, PRINT_EVERY=100)
    print("test 1 result:", t1)
    assert abs(t1) <= 1e-6

    t2 = sgd(quad, 0.0, 0.01, 1000, PRINT_EVERY=100)
    print("test 2 result:", t2)
    assert abs(t2) <= 1e-6

    t3 = sgd(quad, -1.5, 0.01, 1000, PRINT_EVERY=100)
    print("test 3 result:", t3)
    assert abs(t3) <= 1e-6

    print("-" * 40)
    print("ALL TESTS PASSED")
    print("-" * 40)


if __name__ == "__main__":
    sanity_check()